function x_hat = VAMP(y, A, lambda, max_iter, tol)
% VAMP  Estimate x from measurements y of A*x by iterative soft thresholding.
%   x_hat = VAMP(y, A, lambda) starts from x_hat = 0, v = 0, d = dc = 0 and
%   repeats:
%       v     = y - A*x_hat + (d/M)*v + (dc/M)*conj(v)   % corrected residual
%       r     = x_hat + A'*v
%       [x_hat, d, dc] = shrink(r, lambda, norm(v)^2/M)
%   until the relative change norm(x_hat_prev - x_hat)/norm(x_hat) is no
%   longer above tol, or max_iter iterations have been run.
%
%   x_hat = VAMP(y, A, lambda, max_iter, tol) overrides the defaults.
%
%   Inputs:
%     y        - vector with size(A,1) entries (row or column accepted).
%     A        - M-by-N matrix (real or complex).
%     lambda   - threshold multiplier passed to shrink; the effective
%                threshold is lambda*norm(v)/sqrt(M).
%     max_iter - optional cap on the number of iterations (default 201,
%                the cap of the former "t <= 200" loop starting at t = 0).
%     tol      - optional relative-change tolerance (default 1e-5).
%   Output:
%     x_hat    - N-by-1 estimate.
    if nargin < 4 || isempty(max_iter), max_iter = 201; end
    if nargin < 5 || isempty(tol), tol = 1e-5; end

    [M, N] = size(A);
    assert(isvector(y) && M == length(y), 'y must be a vector with size(A,1) entries');
    % Force a column: a row y would implicitly expand against the M-by-1
    % product A*x_hat and silently turn the residual into an M-by-M matrix.
    y = y(:);

    A_H = A';
    x_hat = zeros(N, 1);
    v = zeros(M, 1);
    d = 0;
    dc = 0;

    for iter = 1:max_iter
        x_hat_prev = x_hat;
        v = y - A*x_hat + (d/M)*v + (dc/M)*conj(v);
        r = x_hat + A_H*v;
        sigma2 = norm(v)^2 / M;
        [x_hat, d, dc] = shrink(r, lambda, sigma2);
        % Stop unless the change is strictly above tol relative to norm(x_hat).
        % Written without division; like the ratio form, this stops when both
        % norms are zero or a NaN appears, and continues when only x_hat is zero.
        if ~(norm(x_hat_prev - x_hat) > tol * norm(x_hat))
            break;
        end
    end
end

function [x_hat, d, dc] = shrink(r, lambda, sigma2)
% SHRINK  Complex soft thresholding of r plus two correction sums.
%   [x_hat, d, dc] = SHRINK(r, lambda, sigma2) uses the threshold
%   tau = lambda*sqrt(sigma2). Entries with abs(r) > tau are kept:
%       x_hat = r - tau*r./abs(r)     (magnitude reduced by tau, phase kept)
%   and all other entries of x_hat are set to 0. Over the kept entries only:
%       d  = sum(1 - tau./(2*abs(r)))
%       dc = -sum(tau*r.^2./(2*abs(r).^3))
%   Restricting the sums to kept entries (instead of masking afterwards)
%   keeps r == 0 entries from turning d and dc into NaN via (-Inf)*0 or
%   (0/0)*0.
    tau = lambda * sqrt(sigma2);
    keep = abs(r) > tau;
    rk = r(keep);
    mag = abs(rk);

    % Start from r so x_hat keeps r's shape and class, then zero the rest.
    x_hat = r;
    x_hat(~keep) = 0;
    % rk./mag is the unit phasor of rk; unlike exp(1j*angle(rk)) it is exact
    % (no round-off imaginary part) for negative real entries.
    x_hat(keep) = rk - tau * rk ./ mag;

    d = sum(1 - tau ./ (2 * mag));
    dc = -sum(tau * rk.^2 ./ (2 * mag.^3));
end

